{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 06, case 4a: Poisson problem with Dirichlet control\n",
    "\n",
    "In this tutorial we solve the optimal control problem\n",
    "\n",
    "$$\\min J(y, u) = \\frac{1}{2} \\int_{\\Omega} (y - y_d)^2 dx + \\frac{\\alpha}{2} \\int_{\\Gamma_2} u^2 ds$$\n",
    "s.t.\n",
    "$$\\begin{cases}\n",
    "      - \\Delta y = f     & \\text{in } \\Omega\\\\\n",
    "    \\partial_n y = 0     & \\text{on } \\Gamma_1\\\\\n",
    "               y = u     & \\text{on } \\Gamma_2\\\\\n",
    "    \\partial_n y = 0     & \\text{on } \\Gamma_3\\\\\n",
    "               y = 0     & \\text{on } \\Gamma_4\\\\\n",
    "\\end{cases}$$\n",
    "\n",
    "where\n",
    "$$\\begin{align*}\n",
    "& \\Omega               & \\text{unit square}\\\\\n",
    "& \\Gamma_1             & \\text{bottom boundary of the square}\\\\\n",
    "& \\Gamma_2             & \\text{left boundary of the square}\\\\\n",
    "& \\Gamma_3             & \\text{top boundary of the square}\\\\\n",
    "& \\Gamma_4             & \\text{right boundary of the square}\\\\\n",
    "& u \\in L^2(\\Gamma_2)  & \\text{control variable}\\\\\n",
    "& y \\in H^1(\\Omega)    & \\text{state variable}\\\\\n",
    "& \\alpha > 0           & \\text{penalization parameter}\\\\\n",
    "& y_d                  & \\text{desired state}\\\\\n",
    "& f                    & \\text{forcing term}\n",
    "\\end{align*}$$\n",
    "using an adjoint formulation solved by a one-shot approach."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mpi4py import MPI\n",
    "from petsc4py import PETSc\n",
    "import ufl\n",
    "from ufl import grad, inner, Measure, replace, SpatialCoordinate, TestFunction, TrialFunction\n",
    "from dolfinx import Constant, DirichletBC, Function, FunctionSpace\n",
    "from dolfinx.fem import locate_dofs_topological, set_bc\n",
    "from dolfinx.io import XDMFFile\n",
    "from dolfinx.plot import create_vtk_topology\n",
    "from multiphenicsx.fem import (apply_lifting, assemble_matrix, assemble_matrix_block, assemble_scalar,\n",
    "                               assemble_vector, assemble_vector_block, BlockVecSubVectorWrapper,\n",
    "                               create_vector_block, DofMapRestriction)\n",
    "import pyvista"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the mesh together with its cell (\"subdomains\") and facet (\"boundaries\") tags\n",
    "with XDMFFile(MPI.COMM_WORLD, \"data/square.xdmf\", \"r\") as infile:\n",
    "    mesh = infile.read_mesh()\n",
    "    mesh.topology.create_connectivity_all()\n",
    "    subdomains = infile.read_meshtags(mesh, name=\"subdomains\")\n",
    "    boundaries = infile.read_meshtags(mesh, name=\"boundaries\")\n",
    "\n",
    "# Tagged entity indices grouped by boundary tag: 2 (where the control acts), 4, and both\n",
    "boundaries_2 = boundaries.indices[np.equal(boundaries.values, 2)]\n",
    "boundaries_4 = boundaries.indices[np.equal(boundaries.values, 4)]\n",
    "boundaries_24 = boundaries.indices[np.isin(boundaries.values, [2, 4])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define associated measures\n",
    "dx = Measure(\"dx\")(subdomain_data=subdomains)\n",
    "ds = Measure(\"ds\")(subdomain_data=boundaries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dolfinx_to_pyvista_mesh(mesh):\n",
    "    \"\"\"Build a pyvista unstructured grid from a dolfinx mesh.\n",
    "\n",
    "    The grid is assembled from the VTK topology of the first ``size_local``\n",
    "    cells of the mesh and from the mesh geometry coordinates.\n",
    "    \"\"\"\n",
    "    tdim = mesh.topology.dim\n",
    "    local_cell_count = mesh.topology.index_map(tdim).size_local\n",
    "    local_cells = np.arange(local_cell_count, dtype=np.int32)\n",
    "    vtk_cells, vtk_cell_types = create_vtk_topology(mesh, tdim, local_cells)\n",
    "    return pyvista.UnstructuredGrid(vtk_cells, vtk_cell_types, mesh.geometry.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pyvista_mesh_plot(mesh):\n",
    "    \"\"\"Display the mesh in a pyvista ITK plotter.\"\"\"\n",
    "    pyvista_grid = dolfinx_to_pyvista_mesh(mesh)\n",
    "    itk_plotter = pyvista.PlotterITK()\n",
    "    itk_plotter.add_mesh(pyvista_grid)\n",
    "    itk_plotter.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual check of the mesh that was just read\n",
    "pyvista_mesh_plot(mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Function spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# State (Y) and control (U) spaces: quadratic Lagrange elements on the whole mesh\n",
    "Y, U = FunctionSpace(mesh, (\"Lagrange\", 2)), FunctionSpace(mesh, (\"Lagrange\", 2))\n",
    "# The multiplier space (L) is cloned from the control space,\n",
    "# the adjoint space (Q) from the state space\n",
    "L, Q = U.clone(), Y.clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Restrictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# State and adjoint keep every dof of their space (owned and ghost) ...\n",
    "dofs_Y = np.arange(Y.dofmap.index_map.size_local + Y.dofmap.index_map.num_ghosts)\n",
    "dofs_Q = dofs_Y\n",
    "# ... while control and multiplier only keep the dofs located on the boundary tagged 2\n",
    "dofs_U = locate_dofs_topological(U, boundaries.dim, boundaries_2)\n",
    "dofs_L = dofs_U\n",
    "\n",
    "# One restriction per unknown, collected in the block order (y, u, l, p)\n",
    "restriction_Y = DofMapRestriction(Y.dofmap, dofs_Y)\n",
    "restriction_U = DofMapRestriction(U.dofmap, dofs_U)\n",
    "restriction_L = DofMapRestriction(L.dofmap, dofs_L)\n",
    "restriction_Q = DofMapRestriction(Q.dofmap, dofs_Q)\n",
    "restriction = [restriction_Y, restriction_U, restriction_L, restriction_Q]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Trial and test functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trial functions: state, control, multiplier, adjoint\n",
    "y = TrialFunction(Y)\n",
    "u = TrialFunction(U)\n",
    "l = TrialFunction(L)\n",
    "p = TrialFunction(Q)\n",
    "# Test functions, one per space in the same order\n",
    "z = TestFunction(Y)\n",
    "v = TestFunction(U)\n",
    "m = TestFunction(L)\n",
    "q = TestFunction(Q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 1e-5  # penalization parameter\n",
    "y_d = 1.0  # desired state\n",
    "x = SpatialCoordinate(mesh)\n",
    "# Forcing term of the state equation\n",
    "ff = 10 * ufl.sin(2 * ufl.pi * x[0]) * ufl.sin(2 * ufl.pi * x[1])\n",
    "# Never modified in this notebook: supplies the value of every Dirichlet condition below\n",
    "bc0 = Function(Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimality conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Block system of the optimality conditions: rows are tested with (z, v, m, q),\n",
    "# columns correspond to the unknowns (y, u, l, p).\n",
    "# The blocks multiplied by Constant(mesh, 0.) are zero placeholders.\n",
    "a = [\n",
    "    [y * z * dx, None, l * z * ds(2), inner(grad(p), grad(z)) * dx],\n",
    "    [None, alpha * u * v * ds(2), - l * v * ds(2), None],\n",
    "    [y * m * ds(2), - u * m * ds(2), None, None],\n",
    "    [inner(grad(y), grad(q)) * dx, None, None, Constant(mesh, 0.) * p * q * dx],\n",
    "]\n",
    "f = [\n",
    "    y_d * z * dx,\n",
    "    Constant(mesh, 0.) * v * dx,\n",
    "    Constant(mesh, 0.) * m * dx,\n",
    "    ff * q * dx,\n",
    "]\n",
    "\n",
    "# Dirichlet conditions with value bc0: state on boundary 4, adjoint on boundaries 2 and 4\n",
    "bdofs_Y_4 = locate_dofs_topological((Y, Y), mesh.topology.dim - 1, boundaries_4)\n",
    "bdofs_Q_24 = locate_dofs_topological((Q, Y), mesh.topology.dim - 1, boundaries_24)\n",
    "bc = [\n",
    "    DirichletBC(bc0, bdofs_Y_4, Y),\n",
    "    DirichletBC(bc0, bdofs_Q_24, Q),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# From here on y, u, l, p denote the solution functions instead of the trial functions\n",
    "y = Function(Y)\n",
    "u = Function(U)\n",
    "l = Function(L)\n",
    "p = Function(Q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cost functional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tracking term over the domain plus the control penalization on the boundary tagged 2\n",
    "J = (0.5 * inner(y - y_d, y - y_d) * dx\n",
    "     + 0.5 * alpha * inner(u, u) * ds(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Uncontrolled functional value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract state forms from the optimality conditions\n",
    "a_state = replace(a[3][0], {q: z})\n",
    "f_state = replace(f[3], {q: z})\n",
    "bdofs_Y_24 = locate_dofs_topological((Y, Y), mesh.topology.dim - 1, boundaries_24)\n",
    "bc_state = [DirichletBC(bc0, bdofs_Y_24, Y)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the linear system for the state\n",
    "A_state = assemble_matrix(a_state, bcs=bc_state)\n",
    "A_state.assemble()\n",
    "F_state = assemble_vector(f_state)\n",
    "apply_lifting(F_state, [a_state], [bc_state])\n",
    "F_state.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)\n",
    "set_bc(F_state, bc_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve\n",
    "ksp = PETSc.KSP()\n",
    "ksp.create(mesh.mpi_comm())\n",
    "ksp.setOperators(A_state)\n",
    "ksp.setType(\"preonly\")\n",
    "ksp.getPC().setType(\"lu\")\n",
    "ksp.getPC().setFactorSolverType(\"mumps\")\n",
    "ksp.setFromOptions()\n",
    "ksp.solve(F_state, y.vector)\n",
    "y.vector.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum the per-process values of the functional, evaluated with the control left untouched\n",
    "J_uncontrolled = mesh.mpi_comm().allreduce(assemble_scalar(J), op=MPI.SUM)\n",
    "print(\"Uncontrolled J =\", J_uncontrolled)\n",
    "# Regression check against the reference value of this tutorial\n",
    "assert np.isclose(J_uncontrolled, 0.5038977)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pyvista_scalar_field_plot(mesh, scalar_field, name):\n",
    "    \"\"\"Display a scalar field on the mesh in a pyvista ITK plotter.\n",
    "\n",
    "    The point values of ``scalar_field`` are attached to the grid under the\n",
    "    label ``name`` and selected as the active scalars before plotting.\n",
    "    \"\"\"\n",
    "    pyvista_grid = dolfinx_to_pyvista_mesh(mesh)\n",
    "    pyvista_grid.point_arrays[name] = scalar_field.compute_point_values()\n",
    "    pyvista_grid.set_active_scalars(name)\n",
    "    itk_plotter = pyvista.PlotterITK()\n",
    "    itk_plotter.add_mesh(pyvista_grid)\n",
    "    itk_plotter.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# State obtained from the state solve above, before the optimality system is solved\n",
    "pyvista_scalar_field_plot(mesh, y, \"uncontrolled state\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimal control"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the block linear system for the optimality conditions\n",
    "A = assemble_matrix_block(a, bcs=bc, restriction=(restriction, restriction))\n",
    "A.assemble()\n",
    "F = assemble_vector_block(f, a, bcs=bc, restriction=restriction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve\n",
    "yulp = create_vector_block(f, restriction=restriction)\n",
    "ksp = PETSc.KSP()\n",
    "ksp.create(mesh.mpi_comm())\n",
    "ksp.setOperators(A)\n",
    "ksp.setType(\"preonly\")\n",
    "ksp.getPC().setType(\"lu\")\n",
    "ksp.getPC().setFactorSolverType(\"mumps\")\n",
    "ksp.setFromOptions()\n",
    "ksp.solve(F, yulp)\n",
    "yulp.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the block solution in components\n",
    "with BlockVecSubVectorWrapper(yulp, [Y.dofmap, U.dofmap, L.dofmap, Q.dofmap], restriction) as yulp_wrapper:\n",
    "    for yulp_wrapper_local, component in zip(yulp_wrapper, (y, u, l, p)):\n",
    "        with component.vector.localForm() as component_local:\n",
    "            component_local[:] = yulp_wrapper_local"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum the per-process values of the functional, now evaluated at the optimal (y, u)\n",
    "J_controlled = mesh.mpi_comm().allreduce(assemble_scalar(J), op=MPI.SUM)\n",
    "print(\"Optimal J =\", J_controlled)\n",
    "# Regression check against the reference value of this tutorial\n",
    "assert np.isclose(J_controlled, 0.1281224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimal state\n",
    "pyvista_scalar_field_plot(mesh, y, \"state\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimal control\n",
    "pyvista_scalar_field_plot(mesh, u, \"control\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiplier from the third block of the optimality system\n",
    "pyvista_scalar_field_plot(mesh, l, \"lambda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjoint variable\n",
    "pyvista_scalar_field_plot(mesh, p, \"adjoint\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython"
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
